{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e57c5e2c-04f8-40b7-9b47-e5e05505cb2c",
   "metadata": {},
   "source": [
    "# Training a Diffusion Model with Weights and Biases (W&B)\n",
    "\n",
    "<!--- @wandbcode{dlai_02} -->\n",
    "\n",
    "In this notebooks we will instrument the training of a diffusion model with W&B. We will use the Lab3 notebook from the [\"How diffusion models work\"](https://www.deeplearning.ai/short-courses/how-diffusion-models-work/) course. \n",
    "We will add:\n",
    "- Logging of the training loss and metrics\n",
    "- Sampling from the model during training and uploading the samples to W&B\n",
    "- Saving the model checkpoints to W&B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a34666-2281-49e3-8574-93d57c72771b",
   "metadata": {
    "height": 182,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "from utilities import *\n",
    "\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904d68fe-7435-48a3-b8af-c4be8675311c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "wandb.login(anonymous=\"allow\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02e2b5b2-82e4-4535-aa98-34ae64a808e8",
   "metadata": {},
   "source": [
    "## Setting Things Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4918eda7-6d6b-4f9f-8650-c347ed4a5d1c",
   "metadata": {
    "height": 437,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# we are storing the parameters to be logged to wandb\n",
    "DATA_DIR = Path('./data/')\n",
    "SAVE_DIR = Path('./data/weights/')\n",
    "SAVE_DIR.mkdir(exist_ok=True, parents=True)\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "config = SimpleNamespace(\n",
    "    # hyperparameters\n",
    "    num_samples = 30,\n",
    "\n",
    "    # diffusion hyperparameters\n",
    "    timesteps = 500,\n",
    "    beta1 = 1e-4,\n",
    "    beta2 = 0.02,\n",
    "\n",
    "    # network hyperparameters\n",
    "    n_feat = 64, # 64 hidden dimension feature\n",
    "    n_cfeat = 5, # context vector is of size 5\n",
    "    height = 16, # 16x16 image\n",
    "    \n",
    "    # training hyperparameters\n",
    "    batch_size = 100,\n",
    "    n_epoch = 32,\n",
    "    lrate = 1e-3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ed92a7b-b6a3-4c0c-a35d-154ec26ed923",
   "metadata": {},
   "source": [
    "### Setup DDPM noise scheduler and sampler (same as in the Diffusion course). \n",
    "- perturb_input: Adds noise to the input image at the corresponding timestep on the schedule\n",
    "- sample_ddpm_context: Generate images using the DDPM sampler, we will use this function during training to sample from the model regularly and see how our training is progressing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ba81b76-6521-4c7c-80bd-bacde0361a34",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "# setup ddpm sampler functions\n",
    "perturb_input, sample_ddpm_context = setup_ddpm(config.beta1, \n",
    "                                                config.beta2, \n",
    "                                                config.timesteps, \n",
    "                                                DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c83bd768-f709-410a-8062-703bde7997d8",
   "metadata": {
    "height": 97,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# construct model\n",
    "nn_model = ContextUnet(in_channels=3, \n",
    "                       n_feat=config.n_feat, \n",
    "                       n_cfeat=config.n_cfeat, \n",
    "                       height=config.height).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf98a114-f7aa-4cbd-b08c-d56ad628da21",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# load dataset and construct optimizer\n",
    "dataset = CustomDataset.from_np(path=DATA_DIR)\n",
    "dataloader = DataLoader(dataset, \n",
    "                        batch_size=config.batch_size, \n",
    "                        shuffle=True)\n",
    "optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdccd6e0-850a-41ed-89e7-db629f838770",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2338bec6-319c-4603-8ae6-0e1fcbdd3a4e",
   "metadata": {},
   "source": [
    "We choose a fixed context vector with 6 samples of each class to guide our diffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56bfcd32-1a9c-4d0e-8237-77da217f41ae",
   "metadata": {
    "height": 233,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Noise vector\n",
    "# x_T ~ N(0, 1), sample initial noise\n",
    "noises = torch.randn(config.num_samples, 3, \n",
    "                     config.height, config.height).to(DEVICE)  \n",
    "\n",
    "# A fixed context vector to sample from\n",
    "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0,   # hero\n",
    "                                     1,1,1,1,1,1,   # non-hero\n",
    "                                     2,2,2,2,2,2,   # food\n",
    "                                     3,3,3,3,3,3,   # spell\n",
    "                                     4,4,4,4,4,4]), # side-facing \n",
    "                       5).to(DEVICE).float()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e854b7c7-fa0d-4413-8642-f824449d6763",
   "metadata": {},
   "source": [
    "The following training cell takes very long to run on CPU, we have already trained the model for you on a GPU equipped machine.\n",
    "\n",
    "### You can visit the result of this >> [training here](https://wandb.ai/dlai-course/dlai_sprite_diffusion/runs/pzs3gsyo) <<"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c87ca8f-2c09-487f-a8bc-7030c2b76492",
   "metadata": {
    "height": 947,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# create a wandb run\n",
    "run = wandb.init(project=\"dlai_sprite_diffusion\", \n",
    "                 job_type=\"train\", \n",
    "                 config=config)\n",
    "\n",
    "# we pass the config back from W&B\n",
    "config = wandb.config\n",
    "\n",
    "for ep in tqdm(range(config.n_epoch), leave=True, total=config.n_epoch):\n",
    "    # set into train mode\n",
    "    nn_model.train()\n",
    "    optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
    "    \n",
    "    pbar = tqdm(dataloader, leave=False)\n",
    "    for x, c in pbar:   # x: images  c: context\n",
    "        optim.zero_grad()\n",
    "        x = x.to(DEVICE)\n",
    "        c = c.to(DEVICE)   \n",
    "        context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.8).to(DEVICE)\n",
    "        c = c * context_mask.unsqueeze(-1)        \n",
    "        noise = torch.randn_like(x)\n",
    "        t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(DEVICE) \n",
    "        x_pert = perturb_input(x, t, noise)      \n",
    "        pred_noise = nn_model(x_pert, t / config.timesteps, c=c)      \n",
    "        loss = F.mse_loss(pred_noise, noise)\n",
    "        loss.backward()    \n",
    "        optim.step()\n",
    "\n",
    "        wandb.log({\"loss\": loss.item(),\n",
    "                   \"lr\": optim.param_groups[0]['lr'],\n",
    "                   \"epoch\": ep})\n",
    "\n",
    "\n",
    "    # save model periodically\n",
    "    if ep%4==0 or ep == int(config.n_epoch-1):\n",
    "        nn_model.eval()\n",
    "        ckpt_file = SAVE_DIR/f\"context_model.pth\"\n",
    "        torch.save(nn_model.state_dict(), ckpt_file)\n",
    "\n",
    "        artifact_name = f\"{wandb.run.id}_context_model\"\n",
    "        at = wandb.Artifact(artifact_name, type=\"model\")\n",
    "        at.add_file(ckpt_file)\n",
    "        wandb.log_artifact(at, aliases=[f\"epoch_{ep}\"])\n",
    "\n",
    "        samples, _ = sample_ddpm_context(nn_model, \n",
    "                                         noises, \n",
    "                                         ctx_vector[:config.num_samples])\n",
    "        wandb.log({\n",
    "            \"train_samples\": [\n",
    "                wandb.Image(img) for img in samples.split(1)\n",
    "            ]})\n",
    "        \n",
    "# finish W&B run\n",
    "wandb.finish()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
